import copy
from functools import partial
from types import MethodType

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all

NUM_LAYER = 8
DIM = 4


class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def pp_linear_fwd(
    forward,
    data: torch.Tensor = None,
    input_obj: torch.Tensor = None,
    stage_mgr: PipelineStageManager = None,
    model_chunk_id: int = None,
):
    with stage_mgr.switch_model_chunk_id(model_chunk_id):
        if stage_mgr.is_first_stage():
            return {"input_obj": forward(data)}
        elif stage_mgr.is_last_stage():
            return forward(input_obj)
        else:
            return {"input_obj": forward(input_obj)}


def run_pp(
    rank: int,
    world_size: int,
    port: int,
    num_microbatch: int,
    batch_size: int,
    num_model_chunk: int,
):
    """
    This test is to examine the correctness of interleaved 1F1B, compared with torch.
    Be aware it contains some hardcodes.
    """
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    # create model
    seed_all(1453)
    torch_model = MlpModel().cuda()
    pp_model = copy.deepcopy(torch_model).cuda()

    pg_mesh = ProcessGroupMesh(world_size)
    stage_manager = PipelineStageManager(
        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
    )
    schedule = InterleavedSchedule(
        stage_manager=stage_manager,
        num_model_chunks=num_model_chunk,
        num_microbatch=num_microbatch,
    )

    sharded_model = torch.nn.ModuleList()
    for idx, sub_model in enumerate(pp_model.layers):
        if idx % world_size == rank:
            sub_model._forward = sub_model.forward
            sub_model.forward = MethodType(
                partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)),
                sub_model._forward,
            )
            sharded_model.append(sub_model.cuda())
    assert len(sharded_model) == num_model_chunk, "num_model_chunk is not correct"

    # create optimizer
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5)
    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5))

    # create data
    seed_all(115)
    input_list = [torch.rand(batch_size, DIM).cuda()]
    dist.all_reduce(input_list[0])

    def criterion(x, *args, **kwargs):
        return (x * x).mean()

    # forward and backward
    torch_output = torch_model(input_list[0])
    torch_loss = criterion(torch_output)
    torch_loss.backward()

    pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True)

    # check loss
    if stage_manager.is_last_stage(ignore_chunk=True):
        assert_close(torch_loss, pp_ret["loss"])

    # check gradients
    for i in range(num_model_chunk):
        idx = world_size * i + rank
        assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
        assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

    # step
    torch_optimizer.step()
    pp_optimizer.step()
    pp_optimizer.zero_grad()

    # check updated param
    for i in range(num_model_chunk):
        idx = world_size * i + rank
        assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
        assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)

    # forward only
    with torch.no_grad():
        torch_output = torch_model(input_list[0])
        torch_loss = criterion(torch_output)

        pp_ret = schedule.forward_backward_step(
            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
        )
        if stage_manager.is_last_stage(ignore_chunk=True):
            assert_close(torch_loss, pp_ret["loss"])

        for layer in sharded_model:
            if layer.weight.grad is None:
                assert layer.weight.grad is None and layer.bias.grad is None
            else:
                assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
                assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))


@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 12])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("num_model_chunk", [2, 4])
@rerun_if_address_is_in_use()
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
    assert NUM_LAYER % num_model_chunk == 0
    spawn(
        run_pp,
        nprocs=NUM_LAYER // num_model_chunk,
        num_microbatch=num_microbatch,
        batch_size=batch_size,
        num_model_chunk=num_model_chunk,
    )


if __name__ == "__main__":
    test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)
